{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.5\n",
    "试编程实现基于对率回归进行划分选择的决策树算法，并为表4.3中数据生成一棵决策树。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from ch4 import TreeGenerate_logistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Read the CSV and drop its 'Idx' column (presumably a row index, not a feature).\n",
    "dataset = pd.read_csv('data/table_4_3_watermelon_3_0_num.csv').drop('Idx', axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class Node(object):\n",
    "    \"\"\"One node of the logistic-regression decision tree.\n",
    "\n",
    "    Attributes:\n",
    "        att: value given as `att` (None by default).\n",
    "        label: class label stored on the node.\n",
    "        child: dict of child nodes, keyed by branch name.\n",
    "        w: value given as `weight` (None by default).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, att=None, child=None, label=None, weight=None):\n",
    "        \"\"\"Store the given fields; `child` defaults to a new empty dict.\"\"\"\n",
    "        self.att = att\n",
    "        self.label = label\n",
    "        # A literal `{}` default is evaluated once and would be shared by every\n",
    "        # Node created without `child`, so build a fresh dict per instance.\n",
    "        self.child = {} if child is None else child\n",
    "        self.w = weight\n",
    "        \n",
    "def mode(l):\n",
    "    \"\"\"Return every value of `l` that occurs the maximum number of times.\n",
    "\n",
    "    The result is a list because several values can tie for the top count;\n",
    "    an empty input yields an empty list.\n",
    "    \"\"\"\n",
    "    # Tally how many times each value appears.\n",
    "    tally = {}\n",
    "    for item in l:\n",
    "        tally[item] = tally.get(item, 0) + 1\n",
    "    # Highest count seen (0 when nothing was counted).\n",
    "    top = max(tally.values(), default=0)\n",
    "    # Keep every value whose count reaches that maximum.\n",
    "    return [value for value, times in tally.items() if times == top]\n",
    "\n",
    "def get_one_label(label):\n",
    "    \"\"\"Return a single most-frequent value of `label`.\n",
    "\n",
    "    When several values tie for the highest count, one of them is picked\n",
    "    with `np.random.randint`, so the result is not deterministic on ties.\n",
    "    \"\"\"\n",
    "    candidates = mode(label)\n",
    "    # No tie: the only candidate is the answer.\n",
    "    if len(candidates) <= 1:\n",
    "        return candidates[0]\n",
    "    # Tie: draw one candidate by random index.\n",
    "    return candidates[np.random.randint(len(candidates))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def TreeGenerate_logistic(dataset):\n",
    "    \"\"\"Recursively build a decision tree whose splits are logistic regressions.\n",
    "\n",
    "    Args:\n",
    "        dataset: pandas DataFrame. Every column except the last is used as a\n",
    "            feature; the last column is the class label. Predictions are\n",
    "            compared against 0 to route samples, so labels are presumably\n",
    "            encoded as 0/1 -- confirm against the data file.\n",
    "\n",
    "    Returns:\n",
    "        Node: root of the (sub)tree. Internal nodes carry `w` (the first row\n",
    "        of the fitted coefficients followed by the intercept) and two\n",
    "        children keyed 'y=0' and 'y=1'; leaves carry only `label`.\n",
    "    \"\"\"\n",
    "    att = list(dataset.columns[:-1])\n",
    "    labels = dataset.iloc[:, -1]\n",
    "    # Pass a fresh dict explicitly so this node never shares its child map.\n",
    "    rn = Node(None, {}, None)\n",
    "    # Case 1: all samples in D belong to the same class -> leaf with that class.\n",
    "    classes = np.unique(labels)\n",
    "    if len(classes) == 1:\n",
    "        rn.label = classes[0]\n",
    "        return rn\n",
    "    # Case 2: all samples in D take the same value on every attribute in A\n",
    "    # -> leaf labelled with a single majority class (not the list of modes).\n",
    "    if all(len(np.unique(dataset.loc[:, a])) <= 1 for a in att):\n",
    "        rn.label = get_one_label(labels)\n",
    "        return rn\n",
    "    # Otherwise fit a logistic regression on this subset and split on its output.\n",
    "    rn.label = get_one_label(labels)\n",
    "    features = dataset.iloc[:, :-1]\n",
    "    log_model = LogisticRegression()\n",
    "    log_model.fit(features, labels)\n",
    "    predict_label = log_model.predict(features)\n",
    "    rn.w = np.hstack([log_model.coef_[0], log_model.intercept_])\n",
    "    left_value = 'y=0'\n",
    "    right_value = 'y=1'\n",
    "    # When the model predicts one class for every sample the split separates\n",
    "    # nothing, so both branches become fixed leaves instead of recursing.\n",
    "    one_sided = len(np.unique(predict_label)) == 1\n",
    "    left_part = dataset.loc[predict_label <= 0, :]\n",
    "    if left_part.empty or one_sided:\n",
    "        rn.child[left_value] = Node(None, {}, None)\n",
    "        rn.child[left_value].label = 0\n",
    "    else:\n",
    "        rn.child[left_value] = TreeGenerate_logistic(left_part)\n",
    "    right_part = dataset.loc[predict_label > 0, :]\n",
    "    if right_part.empty or one_sided:\n",
    "        rn.child[right_value] = Node(None, {}, None)\n",
    "        rn.child[right_value].label = 1\n",
    "    else:\n",
    "        rn.child[right_value] = TreeGenerate_logistic(right_part)\n",
    "    return rn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build the tree from the full dataset; `rn` is the root Node.\n",
    "rn = TreeGenerate_logistic(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.10205268, -0.37317245, -0.4981436 , -1.09623118, -0.45255551,\n",
       "        0.18352238,  0.43046137,  0.2955575 ,  0.69335734])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Root weights: the fitted feature coefficients followed by the intercept.\n",
    "rn.w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'y=0': <__main__.Node at 0x15c0056fd0>,\n",
       " 'y=1': <__main__.Node at 0x15c004b5f8>}"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Children of the root, keyed by the branch names 'y=0' and 'y=1'.\n",
    "rn.child"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.1654895 , -0.07651756, -0.65796895, -0.45010717, -0.2973238 ,\n",
       "        0.45282301, -0.2186481 , -0.03678544, -0.30481051])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Weights stored on the root's 'y=0' child (None if that child is a leaf).\n",
    "rn.child['y=0'].w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.21762188, -0.08658134,  0.31408021,  0.        , -0.08658134,\n",
       "       -0.32902918,  0.58208979,  0.17547878,  0.80344473])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Weights stored on the root's 'y=1' child (None if that child is a leaf).\n",
    "rn.child['y=1'].w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Child map two levels down the 'y=0' branch; an empty dict means that node is a leaf.\n",
    "rn.child['y=0'].child['y=0'].child"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
